from datam import *
from model import *
from train import *

if __name__ == '__main__':
    # Smoke-test a saved checkpoint: run a single forward pass on one test
    # batch and print the raw output, then evaluate a separately loaded copy
    # of the model on the whole test loader through the Lightning Trainer.
    import torch

    model = Model.load_from_checkpoint("example.ckpt")
    # eval() puts layers such as dropout / batch-norm into inference mode.
    model.eval()
    # Positional args to the project's DataM; presumably
    # (data_dir, batch_size, num_workers) -- TODO confirm against datam.py.
    data_mnist = DataM("../../../../data/", 32, 0)
    data_mnist.setup()
    data, label = next(iter(data_mnist.test_dataloader()))

    # Inference only: disable autograd so no computation graph is built
    # (saves memory; gradients are never used here).
    with torch.no_grad():
        prediction = model(data)
    print(prediction)

    # Request a GPU only when one is present; previously gpus=1 made the
    # script fail outright on CPU-only machines. gpus=None means CPU.
    # NOTE: max_epochs has no effect on Trainer.test(); kept from the
    # original configuration.
    trainer_clone = pl.Trainer(
        max_epochs=3,
        gpus=1 if torch.cuda.is_available() else None,
    )
    # A second, independently loaded instance -- presumably so the Trainer
    # (which may move the module to another device) leaves `model` untouched.
    model_clone = Model.load_from_checkpoint("example.ckpt")
    result = trainer_clone.test(model_clone, data_mnist.test_dataloader())
    print(result)



